import re
import uuid

from app.base.logger import setup_logger
from app.service.nebula_service import NebulaClient
from app.utils.content_processor import ContentProcessor
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse

from app.api.models import (
    ConnectionRequest,
    MethodSearchRequest,
    NodeRequest,
    NodeResponse,
    NodeDetailRequest,
    NodeDetail
)

# 使用新的日志配置
logger = setup_logger("api_endpoints")

router = APIRouter()

# 会话存储
nebula_sessions = {}


@router.post("/connect")
async def connect(request: ConnectionRequest, req: Request):
    """连接到NebulaGraph数据库"""
    logger.info(f"收到连接请求: IP={request.ip}, Port={request.port}, User={request.username}, Space={request.space_name}")
    
    # 验证IP格式
    ip = request.ip.strip()
    if not ip:
        logger.error("IP地址不能为空")
        raise HTTPException(status_code=400, detail="IP地址不能为空")
    
    # 验证端口格式
    try:
        port = int(request.port)
        if port <= 0 or port > 65535:
            logger.error(f"端口号无效: {port}")
            raise HTTPException(status_code=400, detail="端口号必须在1-65535范围内")
    except (ValueError, TypeError):
        logger.error(f"端口号必须是整数: {request.port}")
        raise HTTPException(status_code=400, detail="端口号必须是整数")
    
    # 验证空间名
    if not request.space_name or not re.match(r'^[a-zA-Z0-9_]+$', request.space_name):
        logger.error(f"无效的图空间名: {request.space_name}")
        raise HTTPException(status_code=400, detail="无效的图空间名")
    
    # 创建新的客户端实例
    client = NebulaClient()
    
    # 尝试连接
    success = client.connect(
        ip,
        port,
        request.username, 
        request.password,
        request.space_name
    )
    
    if not success:
        error_message = client.get_error_message()
        logger.error(f"连接数据库失败: IP={ip}, Port={port}, Space={request.space_name}, 错误: {error_message}")
        raise HTTPException(status_code=500, detail=f"连接数据库失败: {error_message}")
    
    # 生成会话ID
    session_id = str(uuid.uuid4())
    
    # 存储会话信息
    nebula_sessions[session_id] = {
        "client": client,
        "ip": ip,
        "port": port,
        "username": request.username,
        "space_name": request.space_name
    }
    
    logger.info(f"数据库连接成功: IP={ip}, Port={port}, Space={request.space_name}, 会话ID={session_id}")
    return {"status": "success", "message": "连接成功", "session_id": session_id}


@router.post("/disconnect")
async def disconnect(req: Request):
    """断开NebulaGraph数据库连接"""
    session_id = req.headers.get("X-Session-ID")
    logger.info(f"收到断开连接请求, 会话ID={session_id}")
    
    if session_id and session_id in nebula_sessions:
        client = nebula_sessions[session_id]["client"]
        client.disconnect()
        del nebula_sessions[session_id]
        return {"status": "success", "message": "断开连接成功"}
    else:
        logger.warning(f"断开连接失败: 无效的会话ID={session_id}")
        raise HTTPException(status_code=400, detail="无效的会话ID")


# 获取当前会话的客户端
def get_client(req: Request):
    session_id = req.headers.get("X-Session-ID")
    if not session_id or session_id not in nebula_sessions:
        logger.warning(f"无效的会话ID: {session_id}")
        raise HTTPException(status_code=401, detail="未连接到数据库或会话已过期")
    
    return nebula_sessions[session_id]["client"]


@router.post("/search")
async def search_methods(request: MethodSearchRequest, req: Request):
    """搜索方法节点"""
    client = get_client(req)
    logger.info(f"搜索方法: {request.method_name}")
    
    if not client.connected:
        logger.warning("尝试搜索但未连接到数据库")
        return JSONResponse(
            status_code=400,
            content={"status": "error", "message": "请先连接到数据库"}
        )
    
    try:
        result = client.search_method_by_name(request.method_name)
        logger.info(f"搜索结果: 找到 {len(result)} 个方法")
        
        return {
            "status": "success",
            "data": result
        }
    except Exception as e:
        logger.exception(f"搜索方法时发生错误: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": f"搜索失败: {str(e)}"}
        )


@router.post("/nodes", response_model=NodeResponse)
async def get_nodes(request: NodeRequest, req: Request):
    """根据查询类型获取节点及其关系"""
    
    client = get_client(req)
    logger.info(f"获取节点关系: 方法={request.method_full_name}, 查询类型={request.query_type}, 路径深度={request.path_depth}")
    
    if request.query_type == "self":
        # 获取单个节点，并添加边形成自环以满足D3.js的要求
        query = f"""
        MATCH (v:function)
        WHERE v.function.full_name == "{request.method_full_name}"
        RETURN id(v) as node_id, v.function.name, v.function.full_name, v.function.type, v.function.visibility
        """
        result = client.execute_query(query)
        
        nodes = []
        edges = []
        if result:
            for record in result.rows():
                values = record.values  # values是属性不是方法
                # 确保有足够的值
                if len(values) < 5:
                    continue
                
                try:
                    # 获取节点ID - 确保从id(v)获取
                    node_id = values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value
                    
                    # 将查询结果转换为字典 - 不包含content字段
                    node = {
                        "id": node_id,
                        "properties": {
                            "name": values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value,
                            "full_name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                            "type": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                            "visibility": values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value
                        }
                    }
                    nodes.append(node)
                    
                    # 添加自环边以便D3.js处理单节点情况
                    # 这个边不会在可视化中显示，但可以保持数据结构一致性
                    dummy_edge = {
                        "source": node_id,
                        "target": node_id,
                        "properties": {
                            "type": "self_reference"
                        }
                    }
                    edges.append(dummy_edge)
                except Exception as e:
                    logger.error(f"处理节点数据时出错: {str(e)}")
                    continue
        
        logger.info(f"获取单个节点结果: {len(nodes)} 个节点")
        return NodeResponse(nodes=nodes, edges=edges)
    
    elif request.query_type == "upstream":
        # 获取上游方法
        logger.info(f"获取上游方法: {request.method_full_name}, 路径深度: {request.path_depth}")
        result = client.get_upstream_methods(request.method_full_name, request.path_depth)
        logger.info(f"上游方法结果: {len(result['nodes'])} 个节点, {len(result['edges'])} 条边")
        return result
    
    elif request.query_type == "downstream":
        # 获取下游方法
        logger.info(f"获取下游方法: {request.method_full_name}, 路径深度: {request.path_depth}")
        result = client.get_downstream_methods(request.method_full_name, request.path_depth)
        logger.info(f"下游方法结果: {len(result['nodes'])} 个节点, {len(result['edges'])} 条边")
        return result
    
    else:
        logger.warning(f"无效的查询类型: {request.query_type}")
        raise HTTPException(status_code=400, detail="无效的查询类型")


@router.post("/node_detail", response_model=NodeDetail)
async def get_node_detail(request: NodeDetailRequest, req: Request):
    """获取节点详情，包括解压内容"""
    
    client = get_client(req)
    logger.info(f"获取节点详情: 节点ID={request.node_id}")
    
    # 处理特殊ID值
    node_id = request.node_id
    
    # 构建查询
    query = f"""
    MATCH (v:function)
    WHERE id(v) == "{node_id}"
    RETURN id(v), v.function.name, v.function.full_name, v.function.commit_status, v.function.content, v.function.visibility, v.function.line_start, v.function.line_end, v.function.complexity, v.function.is_static, v.function.is_constructor
    """
    
    result = client.execute_query(query)
    if not result or len(result.rows()) == 0:
        logger.warning(f"节点不存在: ID={node_id}")
        raise HTTPException(status_code=404, detail="节点不存在")
    
    # 提取节点属性
    raw_properties = {}
    for record in result.rows():
        values = record.values  # values是属性不是方法
        # 确保有足够的值
        if len(values) < 11:
            continue
            
        try:
            # 将查询结果转换为字典
            raw_properties = {
                "id": values[0].value.decode('utf-8') if isinstance(values[0].value, bytes) else values[0].value,
                "name": values[1].value.decode('utf-8') if isinstance(values[1].value, bytes) else values[1].value,
                "full_name": values[2].value.decode('utf-8') if isinstance(values[2].value, bytes) else values[2].value,
                "commit_status": values[3].value.decode('utf-8') if isinstance(values[3].value, bytes) else values[3].value,
                "visibility": values[5].value.decode('utf-8') if isinstance(values[5].value, bytes) else values[5].value,
                "line_start": int(values[6].value) if values[6].value not in (None, '', 'None') else None,
                "line_end": int(values[7].value) if values[7].value not in (None, '', 'None') else None,
                "complexity": int(values[8].value) if values[8].value not in (None, '', 'None') else None,
                "is_static": values[9].value == 'true' if values[9].value is not None else False,
                "is_constructor": values[10].value == 'true' if values[10].value is not None else False
            }
        except Exception as e:
            logger.error(f"处理节点数据时出错: {str(e)}")
            continue
    
    # 解压内容
    content = values[4].value.decode('utf-8') if isinstance(values[4].value, bytes) else values[4].value
    if content:
        logger.debug(f"解压节点内容: 节点ID={node_id}")
        content = ContentProcessor.decompress(content)
    
    logger.info(f"节点详情获取成功: 节点ID={node_id}, 名称={raw_properties.get('name', '')}")
    
    # 构建响应 - 确保与前端期望的格式兼容
    return NodeDetail(
        id=node_id,
        name=raw_properties.get("name", ""),
        full_name=raw_properties.get("full_name", ""),
        content=content,
        raw_properties=raw_properties
    )


@router.get("/connection_status")
async def check_connection_status(req: Request):
    """检查数据库连接状态"""
    session_id = req.headers.get("X-Session-ID")
    logger.info(f"检查连接状态, 会话ID={session_id}")
    
    if not session_id or session_id not in nebula_sessions:
        logger.warning(f"无效的会话ID: {session_id}")
        raise HTTPException(status_code=401, detail="未连接到数据库或会话已过期")
    
    client = nebula_sessions[session_id]["client"]
    if not client.connected:
        logger.warning("数据库连接已断开")
        # 移除无效会话
        del nebula_sessions[session_id]
        raise HTTPException(status_code=401, detail="数据库连接已断开")
    
    # 执行一个简单查询以验证连接
    try:
        # 简单的连接测试查询
        test_query = "YIELD 1"
        result = client.execute_query(test_query)
        if not result:
            logger.error("连接测试查询失败")
            # 移除无效会话
            del nebula_sessions[session_id]
            raise HTTPException(status_code=500, detail="数据库连接测试失败")
        
        logger.info("连接状态正常")
        return {"status": "connected", "message": "数据库连接正常"}
    except Exception as e:
        logger.exception(f"检查连接状态失败: {str(e)}")
        client.connected = False  # 标记为未连接
        # 移除无效会话
        del nebula_sessions[session_id]
        raise HTTPException(status_code=500, detail=f"检查连接状态失败: {str(e)}") 